/**
 * 
 */
package edu.berkeley.nlp.scripts;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;

import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.PCFGLA.TreeAnnotations;
import edu.berkeley.nlp.PCFGLA.smoothing.NoSmoothing;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.parser.EnglishPennTreebankParseEvaluator;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees;
import edu.berkeley.nlp.syntax.Trees.PennTreeReader;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Filter;
import edu.berkeley.nlp.util.Numberer;

/**
 * Learns a grammar whose substates are grammatical functions (GFs), and uses it to
 * put GF labels back onto trees that carry only categories.
 * <p>
 * Node labels in the treebank have the form {@code CATEGORY} or {@code CATEGORY-GF}.
 * Each category becomes a grammar state, and each GF observed with it becomes one of
 * that state's substates.
 * <p>
 * {@link #main(String[])} trains on one treebank and labels a file of test trees,
 * fixing preterminal substates from the POS tags of aligned gold trees. It then
 * reports labelled-constituent scores against those gold trees.
 *
 * @author petrov
 */
public class GermanSharedTask {

	/** Test trees used when no second command-line argument is given. */
	private static final String DEFAULT_TEST_FILE = "/Users/petrov/Data/german_st/tueba/tueba_tmp";
	/** Gold trees used when no third command-line argument is given. */
	private static final String DEFAULT_GOLD_FILE = "/Users/petrov/Data/german_st/tueba/data02.mrg";

	/** Accepts the {@code @}-prefixed intermediate nodes introduced by binarization, so they can be spliced out. */
	private static final Filter<String> BINARIZATION_NODE_FILTER = new Filter<String>() {
		public boolean accept(String s) {
			return s.startsWith("@");
		}
	};

	/** Maps category names (label text before the first '-') to state indices. */
	Numberer tagNumberer;
	/** For each state index, maps GF strings ("" when a label has no GF) to substate indices. */
	List<Numberer> substateNumberers;

	/**
	 * Builds a grammar from trees whose labels carry GFs as {@code CATEGORY-GF}.
	 * <p>
	 * Every node's substate is observed: it is the GF on its label. Rule counts are
	 * tallied from these fully observed trees and then normalised with a single,
	 * unsmoothed M step.
	 *
	 * @param trainTrees training trees with {@code CATEGORY[-GF]} labels
	 * @return the estimated grammar
	 */
	public Grammar extractGrammar(List<Tree<String>> trainTrees) {
		tagNumberer = Numberer.getGlobalNumberer("tags");
		substateNumberers = new ArrayList<Numberer>();

		short[] numSubStates = countSymbols(trainTrees);
		List<Tree<String>> trainTreesNoGF = stripOffGF(trainTrees);
		StateSetTreeList stateSetTrees = new StateSetTreeList(trainTreesNoGF, numSubStates, false, tagNumberer);
		return createGrammar(stateSetTrees, trainTrees, numSubStates);
	}

	/**
	 * Labels every test tree with GFs and scores the result against gold trees, then
	 * prints the evaluation.
	 * <p>
	 * Test and gold trees are aligned in order. Gold trees whose yield length differs
	 * from the current test tree are skipped; presumably those trees are missing from
	 * the test file (TODO confirm). Evaluation stops early if no gold tree of the
	 * required length remains.
	 *
	 * @param grammar   grammar produced by {@link #extractGrammar(List)}
	 * @param testTrees trees to label; any GFs they carry are stripped first
	 * @param goldTrees reference trees with GFs; their POS tags fix the preterminal substates
	 */
	private void checkGrammar(Grammar grammar, List<Tree<String>> testTrees, List<Tree<String>> goldTrees) {
		EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String> eval =
				new EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String>(
						new HashSet<String>(Arrays.asList(new String[] {"ROOT","PSEUDO"})),
						new HashSet<String>(Arrays.asList(new String[] {"''", "``", ".", ":", ","})));

		List<Tree<String>> testTreesNoGF = stripOffGF(testTrees);
		StateSetTreeList stateSetTrees = new StateSetTreeList(testTreesNoGF, grammar.numSubStates, false, tagNumberer);

		int goldIndex = 0;
		for (Tree<StateSet> stateSetTree : stateSetTrees) {
			int length = stateSetTree.getYield().size();
			while (goldIndex < goldTrees.size() && goldTrees.get(goldIndex).getYield().size() != length) {
				goldIndex++;
			}
			if (goldIndex >= goldTrees.size()) {
				// Previously this ran past the end of goldTrees and threw IndexOutOfBoundsException.
				System.err.println("No gold tree of length " + length + " left to compare against; stopping evaluation.");
				break;
			}
			Tree<String> goldTree = goldTrees.get(goldIndex++);

			List<String> goldPOS = goldTree.getPreTerminalYield();
			Tree<String> labeledTree = guessGF(stateSetTree, grammar, goldPOS);
			Tree<String> debinarizedTree = Trees.spliceNodes(labeledTree, BINARIZATION_NODE_FILTER);
			Tree<String> goldDebTree = Trees.spliceNodes(goldTree, BINARIZATION_NODE_FILTER);
			eval.evaluate(goldDebTree, debinarizedTree);
		}
		eval.display(true);
	}

	/**
	 * Labels every tree with GFs and prints the debinarized result to stdout, one
	 * tree per paragraph.
	 *
	 * @param grammar     grammar produced by {@link #extractGrammar(List)}
	 * @param trainTrees  trees to label; any GFs they carry are stripped first
	 * @param goldPOStags one POS-tag sequence per tree, matched by position
	 */
	private void labelTrees(Grammar grammar, List<Tree<String>> trainTrees, List<List<String>> goldPOStags) {
		List<Tree<String>> trainTreesNoGF = stripOffGF(trainTrees);
		StateSetTreeList stateSetTrees = new StateSetTreeList(trainTreesNoGF, grammar.numSubStates, false, tagNumberer);

		int index = 0;
		for (Tree<StateSet> stateSetTree : stateSetTrees) {
			List<String> goldPOS = goldPOStags.get(index++);
			Tree<String> labeledTree = guessGF(stateSetTree, grammar, goldPOS);
			Tree<String> debinarizedTree = Trees.spliceNodes(labeledTree, BINARIZATION_NODE_FILTER);
			System.out.println(debinarizedTree + "\n");
		}
	}

	/**
	 * Assigns a GF to every node of a GF-less tree.
	 * <p>
	 * It fills in inside scores (with preterminals fixed by {@code goldPOS}) and then
	 * decodes top-down from substate 0 at the root.
	 *
	 * @param stateSetTree tree whose StateSet labels are filled in as a side effect
	 * @param grammar      GF grammar
	 * @param goldPOS      POS tags indexed by word position
	 * @return a string tree with {@code CATEGORY[-GF]} labels
	 */
	private Tree<String> guessGF(Tree<StateSet> stateSetTree, Grammar grammar, List<String> goldPOS) {
		doInsideScores(stateSetTree, grammar, goldPOS);
		return extractBestViterbiDerivation(grammar, stateSetTree, 0);
	}

	/**
	 * Returns structural copies of the trees in which every internal label is cut at
	 * its first '-', so {@code CAT-GF} becomes {@code CAT}. Leaves (words) are left
	 * untouched, and the input trees are not modified.
	 *
	 * @param trainTrees trees with {@code CATEGORY[-GF]} labels
	 * @return new trees with category-only labels
	 */
	private List<Tree<String>> stripOffGF(List<Tree<String>> trainTrees) {
		List<Tree<String>> trainTreesNoGF = new ArrayList<Tree<String>>(trainTrees.size());
		for (Tree<String> tree : trainTrees) {
			Tree<String> copy = tree.shallowClone();
			for (Tree<String> node : copy.getPostOrderTraversal()) {
				// Must test the visited node, not the root: words such as "Baden-Württemberg" or "-" contain '-'.
				if (node.isLeaf()) continue;
				String label = node.getLabel();
				int cutIndex = label.indexOf('-');
				if (cutIndex != -1) node.setLabel(label.substring(0, cutIndex));
			}
			trainTreesNoGF.add(copy);
		}
		return trainTreesNoGF;
	}

	/**
	 * Creates an unsmoothed grammar by tallying each StateSet tree with its observed
	 * substates, then running one M step.
	 *
	 * @param stateSetTrees GF-less trees, aligned by position with {@code trainTrees}
	 * @param trainTrees    the same trees with GF labels, used to set the observed substates
	 * @param numSubStates  number of substates (GFs) per state
	 * @return the estimated grammar
	 */
	private Grammar createGrammar(StateSetTreeList stateSetTrees, List<Tree<String>> trainTrees, short[] numSubStates) {
		Grammar grammar = new Grammar(numSubStates, false, new NoSmoothing(), null, -1);
		int index = 0;
		for (Tree<StateSet> stateSetTree : stateSetTrees) {
			Tree<String> tree = trainTrees.get(index++);
			setScores(stateSetTree, tree);
			grammar.tallyStateSetTree(stateSetTree, grammar);
		}
		grammar.optimize(0); // M Step
		return grammar;
	}

	/**
	 * Marks the observed GF substate of every internal node.
	 * <p>
	 * For each node, the inside and outside scores of the observed substate are set
	 * to 1.0 and both scales are set to 0. Prints "Mismatch!" when the two trees
	 * disagree in branching.
	 *
	 * @param stateSetTree StateSet tree to annotate in place
	 * @param tree         corresponding tree with {@code CATEGORY[-GF]} labels
	 */
	private void setScores(Tree<StateSet> stateSetTree, Tree<String> tree) {
		if (tree.isLeaf()) return;
		String[] labels = splitLabel(tree.getLabel());
		StateSet stateSet = stateSetTree.getLabel();
		int substate = substateNumberers.get(stateSet.getState()).number(labels[1]);
		stateSet.setIScore(substate, 1.0);
		stateSet.setIScale(0);
		stateSet.setOScore(substate, 1.0);
		stateSet.setOScale(0);

		int nChildren = tree.getChildren().size();
		if (nChildren != stateSetTree.getChildren().size()) System.err.println("Mismatch!");
		for (int i = 0; i < nChildren; i++) {
			setScores(stateSetTree.getChildren().get(i), tree.getChildren().get(i));
		}
	}

	/**
	 * Registers every category and GF occurring in the trees.
	 *
	 * @param trainTrees trees with {@code CATEGORY[-GF]} labels
	 * @return the number of distinct GFs (substates) per state index
	 */
	private short[] countSymbols(List<Tree<String>> trainTrees) {
		for (Tree<String> tree : trainTrees) {
			processTree(tree);
		}
		short[] numSubStates = new short[tagNumberer.total()];
		for (int state = 0; state < numSubStates.length; state++) {
			numSubStates[state] = (short) substateNumberers.get(state).total();
		}
		return numSubStates;
	}

	/**
	 * Recursively numbers the category and GF of every internal node.
	 *
	 * @param tree a non-leaf tree
	 */
	private void processTree(Tree<String> tree) {
		String[] labels = splitLabel(tree.getLabel());
		int state = tagNumberer.number(labels[0]);

		// The tag numberer is global and may hand out indices beyond our list; pad until `state` is valid.
		while (state >= substateNumberers.size()) {
			substateNumberers.add(new Numberer());
		}
		substateNumberers.get(state).number(labels[1]);

		for (Tree<String> child : tree.getChildren()) {
			if (!child.isLeaf()) processTree(child);
		}
	}

	/**
	 * Splits a label at its first '-' into {@code {category, gf}}.
	 * <p>
	 * The GF is {@code ""} when absent. Everything after the first '-' is kept as the
	 * GF, so compound functions such as {@code NX-OA-MOD} round-trip intact.
	 *
	 * @param label a node label
	 * @return a two-element array {@code {category, gf}}
	 */
	private String[] splitLabel(String label) {
		String[] labels = label.split("-", 2);
		if (labels.length == 1) labels = new String[] {labels[0], ""};
		return labels;
	}

	/**
	 * Decodes a GF-labelled string tree top-down from inside scores computed by
	 * {@link #doInsideScores}.
	 * <p>
	 * Because those inside scores are sums, this is a greedy decode rather than an
	 * exact Viterbi search. Given a node's substate, each child receives the
	 * substate(s) contributing most to it, i.e. the argmax of rule score times child
	 * inside scores. Argmax is also insensitive to the rescaling applied by
	 * {@code scaleIScores}.
	 * <p>
	 * If the requested substate has no inside mass, the node's highest-scoring
	 * substate is used instead. Passing -1 means "no preference" and is treated as 0.
	 *
	 * @param grammar  GF grammar
	 * @param tree     StateSet tree with inside scores filled in
	 * @param substate substate to decode this node with, or -1
	 * @return a string tree with {@code CATEGORY[-GF]} labels
	 */
	Tree<String> extractBestViterbiDerivation(Grammar grammar, Tree<StateSet> tree, int substate) {
		if (tree.isLeaf()) return new Tree<String>(tree.getLabel().getWord());
		if (substate == -1) substate = 0;

		StateSet node = tree.getLabel();
		short pState = node.getState();
		int state = pState;

		if (tree.isPreTerminal()) {
			List<Tree<String>> child = new ArrayList<Tree<String>>();
			child.add(extractBestViterbiDerivation(grammar, tree.getChildren().get(0), -1));
			String goalStr = (String) tagNumberer.object(state);
			return new Tree<String>(appendGF(goalStr, state, substate), child);
		}

		// Scores are probabilities, so a non-positive value means this substate cannot derive the subtree.
		if (!(node.getIScore(substate) > 0)) {
			substate = DoubleArrays.argMax(node.getIScores());
		}

		List<Tree<StateSet>> children = tree.getChildren();
		List<Tree<String>> newChildren = new ArrayList<Tree<String>>();
		switch (children.size()) {
		case 1: {
			StateSet child = children.get(0).getLabel();
			short cState = child.getState();
			int nChildStates = child.numSubStates();
			double[][] uscores = grammar.getUnaryScore(pState, cState);
			int bestChild = -1;
			double bestScore = 0;
			for (int j = 0; j < nChildStates; j++) {
				if (uscores[j] == null) continue;
				double cS = child.getIScore(j);
				if (cS == 0) continue;
				double score = uscores[j][substate] * cS;
				if (score > bestScore) {
					bestScore = score;
					bestChild = j;
				}
			}
			newChildren.add(extractBestViterbiDerivation(grammar, children.get(0), bestChild));
			break;
		}
		case 2: {
			StateSet leftChild = children.get(0).getLabel();
			StateSet rightChild = children.get(1).getLabel();
			int nLeftChildStates = leftChild.numSubStates();
			int nRightChildStates = rightChild.numSubStates();
			short lState = leftChild.getState();
			short rState = rightChild.getState();
			double[][][] bscores = grammar.getBinaryScore(pState, lState, rState);
			int bestLeft = -1, bestRight = -1;
			double bestScore = 0;
			for (int j = 0; j < nLeftChildStates; j++) {
				double lcS = leftChild.getIScore(j);
				if (lcS == 0) continue;
				for (int k = 0; k < nRightChildStates; k++) {
					double rcS = rightChild.getIScore(k);
					if (rcS == 0) continue;
					if (bscores[j][k] == null) continue; // no rule from this parent to these kids
					double score = bscores[j][k][substate] * lcS * rcS;
					if (score > bestScore) {
						bestScore = score;
						bestLeft = j;
						bestRight = k;
					}
				}
			}
			newChildren.add(extractBestViterbiDerivation(grammar, children.get(0), bestLeft));
			newChildren.add(extractBestViterbiDerivation(grammar, children.get(1), bestRight));
			break;
		}
		default:
			throw new Error("Malformed tree: more than two children");
		}

		String parentString = (String) tagNumberer.object(state);
		// Drop a trailing "^g" annotation marker from the category name (its origin is not visible here).
		if (parentString.endsWith("^g")) parentString = parentString.substring(0, parentString.length() - 2);
		return new Tree<String>(appendGF(parentString, state, substate), newChildren);
	}

	/**
	 * Appends {@code "-" + gf} to a category, where gf is the GF numbered
	 * {@code substate} for {@code state}. The category is returned unchanged when
	 * the GF is empty.
	 */
	private String appendGF(String category, int state, int substate) {
		String gfStr = (String) substateNumberers.get(state).object(substate);
		return gfStr.equals("") ? category : category + "-" + gfStr;
	}

	/**
	 * Tests two doubles for equality within a relative tolerance of 1e-4.
	 * Retained for subclasses; the decoder in this class no longer needs it.
	 */
	protected boolean matches(double x, double y) {
		return (Math.abs(x - y) / (Math.abs(x) + Math.abs(y) + 1e-10) < 1.0e-4);
	}

	/**
	 * Computes (summed) inside scores bottom-up for a fixed tree.
	 * <p>
	 * At each preterminal, the substate is fixed to the GF of the gold POS tag at that
	 * position. An unseen GF falls back to substate 0 with a warning. A preterminal
	 * whose state is beyond the grammar is replaced by a one-substate StateSet for
	 * the grammar's last state.
	 *
	 * @param tree    StateSet tree; scores are written into its labels
	 * @param grammar GF grammar
	 * @param goldPOS POS tags indexed by word position ({@code StateSet.from})
	 */
	void doInsideScores(Tree<StateSet> tree, Grammar grammar, List<String> goldPOS) {
		if (tree.isLeaf()) {
			return;
		}
		List<Tree<StateSet>> children = tree.getChildren();
		for (Tree<StateSet> child : children) {
			if (!child.isLeaf()) doInsideScores(child, grammar, goldPOS);
		}
		StateSet parent = tree.getLabel();
		short pState = parent.getState();
		int nParentStates = parent.numSubStates();

		if (tree.isPreTerminal()) {
			// Plays a role similar to initializeChart(): the gold POS tag fixes the substate.
			String POS = goldPOS.get(parent.from);
			String[] labels = splitLabel(POS);
			int substate = 0;
			if (pState < grammar.numStates) {
				substate = substateNumberers.get(pState).number(labels[1]);
				if (substate >= grammar.numSubStates[pState]) {
					System.err.println("Have never seen this POS: " + POS);
					substate = 0;
				}
			} else {
				parent = new StateSet((short) (grammar.numStates - 1), (short) 1);
				tree.setLabel(parent);
			}
			parent.setIScore(substate, 1.0);
			parent.scaleIScores(0);
			return;
		}

		switch (children.size()) {
		case 0:
			break;
		case 1: {
			StateSet child = children.get(0).getLabel();
			short cState = child.getState();
			int nChildStates = child.numSubStates();
			double[][] uscores = grammar.getUnaryScore(pState, cState);
			double[] iScores = new double[nParentStates];
			for (int j = 0; j < nChildStates; j++) {
				if (uscores[j] == null) continue; // no parent substate produces this child substate
				double cS = child.getIScore(j);
				if (cS == 0) continue;
				for (int i = 0; i < nParentStates; i++) {
					double rS = uscores[j][i]; // rule score
					if (rS == 0) continue;
					iScores[i] += rS * cS;
				}
			}
			parent.setIScores(iScores);
			parent.scaleIScores(child.getIScale());
			break;
		}
		case 2: {
			StateSet leftChild = children.get(0).getLabel();
			StateSet rightChild = children.get(1).getLabel();
			int nLeftChildStates = leftChild.numSubStates();
			int nRightChildStates = rightChild.numSubStates();
			short lState = leftChild.getState();
			short rState = rightChild.getState();
			double[][][] bscores = grammar.getBinaryScore(pState, lState, rState);
			double[] iScores = new double[nParentStates];
			for (int j = 0; j < nLeftChildStates; j++) {
				double lcS = leftChild.getIScore(j);
				if (lcS == 0) continue;
				for (int k = 0; k < nRightChildStates; k++) {
					double rcS = rightChild.getIScore(k);
					if (rcS == 0) continue;
					if (bscores[j][k] == null) continue; // no parent substate produces these kids
					for (int i = 0; i < nParentStates; i++) {
						double rS = bscores[j][k][i];
						if (rS == 0) continue;
						iScores[i] += rS * lcS * rcS;
					}
				}
			}
			parent.setIScores(iScores);
			parent.scaleIScores(leftChild.getIScale() + rightChild.getIScale());
			break;
		}
		default:
			throw new Error("Malformed tree: more than two children");
		}
	}

	/**
	 * Reads all bracketed trees from a UTF-8 file and closes the file afterwards.
	 *
	 * @param inputFile path of the treebank file
	 * @return the trees in file order
	 * @throws IllegalArgumentException if the file cannot be opened
	 */
	private static List<Tree<String>> loadTrees(String inputFile) {
		FileInputStream in;
		try {
			in = new FileInputStream(inputFile);
		} catch (FileNotFoundException e) {
			// Previously swallowed, leading to a NullPointerException in PennTreeReader.
			throw new IllegalArgumentException("Cannot open treebank file: " + inputFile, e);
		}
		List<Tree<String>> trees = new ArrayList<Tree<String>>();
		try {
			PennTreeReader treeReader = new PennTreeReader(new InputStreamReader(in, "UTF-8"));
			while (treeReader.hasNext()) {
				trees.add(treeReader.next());
			}
		} catch (UnsupportedEncodingException e) {
			throw new IllegalStateException("UTF-8 encoding is not supported", e);
		} finally {
			try {
				in.close();
			} catch (IOException e) {
				// All trees have been read already; failing to close is not fatal.
				System.err.println("Warning: could not close " + inputFile + ": " + e);
			}
		}
		return trees;
	}

	/**
	 * Command-line entry point.
	 *
	 * @param args {@code args[0]}: training treebank (required).
	 *        {@code args[1]}: trees to label (optional; defaults to {@link #DEFAULT_TEST_FILE}).
	 *        {@code args[2]}: gold trees (optional; defaults to {@link #DEFAULT_GOLD_FILE}).
	 *        All files are read as UTF-8 bracketed trees.
	 */
	public static void main(String[] args) {
		if (args.length < 1) {
			System.err.println("Usage: java " + GermanSharedTask.class.getName() + " <trainTrees> [testTrees] [goldTrees]");
			System.exit(1);
		}
		List<Tree<String>> trainTrees = loadTrees(args[0]);

		GermanSharedTask grEx = new GermanSharedTask();
		Grammar grammar = grEx.extractGrammar(trainTrees);

		List<Tree<String>> testTrees = loadTrees(args.length > 1 ? args[1] : DEFAULT_TEST_FILE);
		List<Tree<String>> goldTrees = loadTrees(args.length > 2 ? args[2] : DEFAULT_GOLD_FILE);
		grEx.checkGrammar(grammar, testTrees, goldTrees);
		// To print the labelled trees instead of evaluating them, call
		// grEx.labelTrees(grammar, testTrees, goldPOS), where goldPOS holds
		// t.getPreTerminalYield() for each gold tree t.
	}

}
